-
Notifications
You must be signed in to change notification settings - Fork 445
Add ChangeDetectionTask #2422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ChangeDetectionTask #2422
Conversation
I wonder if we should limit the scope to change between two timesteps and binary change - then we can use binary metrics and provide a template for the plot methods. I say this because this is the most common change detection task by a mile. Might also simplify the augmentations approach? Treating as a video sequence seems overkill. |
I'm personally okay with this, although @hfangcat has a recent work using multiple pre-event images that would be nice to support someday (could be a subclass if necessary).
Again, this would probably be fine as a starting point, although I would someday like to make all trainers support binary/multiclass/multilabel, e.g., #2219.
Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers: #2184
Agreed.
I actually like the video augmentations, but let me loop in the Kornia folks to get their opinion: @edgarriba @johnnv1
Correct, see #2382 for the big picture (I think I also sent you a recording of my presented plan). |
Can you try
@ashnair1 would this work directly with |
I will go ahead and make changes for this to be for binary change and two timesteps, sounds like a good starting point.
I tried this and it didn't get rid of the other dimension. I also looked into
I was going to add plotting in the trainer, but would you rather not then? What would this look like in the dataset? |
Perhaps there should even be a base class ChangeDetection and subclasses for BinaryChangeDetection etc? |
That's exactly what I'm trying to undo in #2219. |
We can copy-n-paste the
See |
I've updated this to now support only binary change with two timesteps. I still haven't been able to figure out how to make |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you resolve the merge conflicts so we can run the tests?
I'm going to need some help figuring out how to get Also, disregard my earlier comments about Kornia |
I see that in the automated pytest checks (in tests / minimum) there was a syntax error, but I don't see this issue locally. How do I resolve this? |
Couldn't reproduce either, might be a bug specific to older Python versions. Anyway, undid the change on the line it was complaining about, let's see if that fixes it. |
I noticed I also need to update tests/datasets/test_oscd.py. And I think we are removing tests/datamodules/test_oscd.py per #978 now that we have a change detection task, right? Also, how do you usually run prettier locally? I see there is an issue in the prettier check here but it isn't installed in the devcontainer unless I'm mistaken. |
Correct.
To be honest, I rarely run prettier locally. But here are the docs on how to do it: https://torchgeo.readthedocs.io/en/latest/user/contributing.html#linters I've also never used the devcontainer before. But it looks like there is a VS Code extension for prettier: https://marketplace.visualstudio.com/items?itemName=esbenp.prettier-vscode. Feel free to add that to the devcontainer in a separate PR if you want. |
@adamjstewart I think this is a remaining question to resolve on this PR, and I realized that this is already what |
@isaaccorley and @adamjstewart I think this is ready. Let me know what I should put in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This lgtm. I believe it should be version 0.8 since this would be a new feature and that is the next milestone. Will wait for Adam's review though before merging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could add the denormalization for plotting, but as far as adding support for multiclass and multilabel, we specifically decided that this PR would just support binary change detection, so I'd like to keep it like that and someone else can extend it to those cases later. |
I could have a look at adding multiclass & multilabel support once this PR is merged. As I have done something similar already for my research. If you @adamjstewart would guide me a bit for my first PR to torchgeo on this one? |
@keves1 yes, let's add denormalizing and switch from if-statements to match-statements. I'll let @hkristen take a stab at non-binary change detection in a follow-up PR after this is merged. I know @hfangcat is also interested in change detection for 3+ images, I'm not sure if that would be a feature to add to ChangeDetectionTask or SemanticSegmentationTask. I would like to merge this PR soon as I think it's almost ready, we can iterate on additional features later. |
@adamjstewart I added denormalizing and switched to match statements, as well as made the other changes you suggested. I think it's ready now. |
Looks like the way I did the monkeypatch to test |
torchgeo/trainers/change.py
Outdated
keepdim=True, | ||
) | ||
batch = aug(batch) | ||
batch['prediction'] = y_hat.argmax(dim=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be y_hat >= threshold
, not argmax? See SemanticSegmentationTask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right that it should be using the threshold, not argmax. But in SemanticSegmentationTask
(and here), shouldn't sigmoid be applied to y_hat
before thresholding? So it would be batch['prediction'] = (y_hat.sigmoid() >= 0.5).long()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you are right, good catch! Want to submit a separate PR to fix ClassificationTask and SemanticSegmentationTask? Then we can backport it to 0.7.1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Want to submit a separate PR to fix ClassificationTask and SemanticSegmentationTask?
I'll let someone else do that. But I've changed it here.
Very close to being ready to merge, excited to finally get this in!!! |
@adamjstewart Let me know if this looks ready now! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the hard work on this, so happy to finally merge!
@hkristen now that this is merged, you can take a stab at adding multiclass/multilabel support if you want. See #2219 for what this change looked like for SemanticSegmentationTask (code should be almost identical). Also see https://torchgeo.readthedocs.io/en/latest/user/contributing.html for general contributing guidelines. Happy to help with any questions you have or tests you encounter trouble with. |
You're welcome, thanks for your help, and glad it could be merged! |
This PR is to add a change detection trainer as mentioned in #2382.
Key points/items to discuss:
AugmentationSequential
doesn’t work, but can be combined withVideoSequential
to support the temporal dimension (see Kornia docs). I overrodeself.aug
in theOSCDDataModule
to do this but not sure if this should be incorporated into theBaseDataModule
instead.VideoSequential
adds a temporal dimension to the mask. Not sure if there is a way to avoid this, or if this is desirable, but I added an if statement to theAugmentationSequential
wrapper to check for and remove this added dimension._RandomNCrop
augmentation, but this does not work for time series data. I'm not sure how to modify_RandomNCrop
to fix this and would appreciate some help/guidance.cc @robmarkcole